import os


# Directory of per-video JSON predictions that are loaded as `own_opinion`.
save_folder_own = "path/to/modela/prediction/"

# Directory of per-video JSON predictions that are loaded as `other_opinion`.
# NOTE(review): this placeholder is identical to `save_folder_own`; left as-is,
# both opinions are read from the same files.
save_folder_others =  "path/to/modela/prediction/"

# Output directory for this script's per-video JSON results.
save_folder =  "path/to/save_path/"

# exist_ok=True avoids the check-then-create race when several instances of
# this script (e.g. one per GPU / folder) start at the same time.
os.makedirs(save_folder, exist_ok=True)


import sys

# Command-line interface: <script> <folder_index> <gpu_index>
if len(sys.argv) < 3:
    sys.exit("usage: {} <folder_index> <gpu_index>".format(sys.argv[0]))

folder = int(sys.argv[1])   # index of the predefined test split to process
gpu_index = sys.argv[2]     # value written to CUDA_VISIBLE_DEVICES

import os
# These must be set before torch initialises CUDA (torch is imported further
# down in this file), otherwise the device mask has no effect.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]=gpu_index


# Test-loader positions (0-based order of `test_loader`) processed by each
# split; the split is chosen by the `folder` command-line argument.
# Contents are kept exactly as originally specified (including overlaps and
# duplicates between/within splits).
FOLDER_TEST_SETS = {
    0: list(range(0, 14)) + list(range(200, 204)) + [249],
    1: list(range(28, 58)) + list(range(204, 208)) + [250],
    2: list(range(72, 110)) + list(range(208, 212)) + [275],
    3: list(range(116, 140)) + list(range(212, 216)) + [276],
    4: list(range(160, 188)) + list(range(216, 220)) + [282],
    5: [188] + list(range(220, 224)) + [283],
    6: [199] + list(range(224, 228)),
    7: list(range(200, 230)) + list(range(228, 230)) + [284],
    8: (list(range(158, 160)) + list(range(251, 252)) + list(range(268, 275))
        + list(range(262, 268)) + list(range(105, 110))),  # 252-258
    9: (list(range(189, 199)) + list(range(277, 290)) + [256] + [253]
        + list(range(260, 262)) + [252]),
    10: list(range(140, 158)),
    11: (list(range(258, 260)) + list(range(254, 256)) + [257]
         + list(range(58, 68)) + list(range(143, 146))),
    12: list(range(110, 115)) + list(range(237, 245)) + list(range(146, 150)),
    13: (list(range(68, 72)) + [115] + list(range(275, 277))
         + list(range(245, 251)) + list(range(230, 237)) + list(range(150, 154))),
    14: list(range(14, 28)) + list(range(154, 158)) + list(range(48, 58)),
    15: list(range(48, 58)) + list(range(285, 290)),
}

# Fail fast: an unknown index previously left `folder_test_set` undefined and
# only crashed (NameError) after the model had been loaded.
if folder not in FOLDER_TEST_SETS:
    raise ValueError(
        "unknown folder index {!r}; expected one of {}".format(
            folder, sorted(FOLDER_TEST_SETS)))
folder_test_set = FOLDER_TEST_SETS[folder]

from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import decord
from decord import VideoReader, cpu
import random
import torch
from torch.utils.data.dataloader import default_collate
from PIL import Image
from typing import Dict, Optional, Sequence
import transformers
import pathlib
import json
import pickle
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
import copy
import math
from torchvision import transforms
import pdb
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel
import pytorch_lightning as pl
import itertools

import warnings
from typing import Any, List, Optional, Tuple, Union

import torch.distributed as dist
import torch.utils.checkpoint
import transformers

import sys
sys.path.append('InternVL/internvl_chat')

from internvl.conversation import get_conv_template
from internvl.model.internlm2.modeling_internlm2 import InternLM2ForCausalLM
from internvl.model.phi3.modeling_phi3 import Phi3ForCausalLM
from peft import LoraConfig, get_peft_model
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
                          LlamaTokenizer, Qwen2ForCausalLM)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, logging


from datasets.video_instruction_dataset_inference import dynamic_preprocess
from datasets.video_instruction_dataset_inference import Video_Instruct_Dataset_Inference
            
class VideoAnomalyDetectionModelInference(pl.LightningModule):
    """Inference wrapper that asks an InternVL chat model to reconcile two
    stored anomaly analyses of a video snippet into a binary label.

    Of the constructor arguments, only ``model``, ``tokenizer`` and
    ``generation_config`` are read by the methods below. ``optimizer_instruct``,
    ``model_instruct`` and ``new_qs`` are stored but not used (``forward``
    builds its own prompt), and ``epochs`` is ignored; they are kept for
    interface compatibility.
    """

    # Header after which a stored opinion's per-question answers start.
    _ANSWERS_MARKER = 'Answers to Prompt Questions:'

    def __init__(self, model, tokenizer, optimizer_instruct, model_instruct, new_qs, generation_config, epochs=5):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.optimizer_instruct = optimizer_instruct
        self.model_instruct = model_instruct
        self.new_qs = new_qs
        self.generation_config = generation_config

    @classmethod
    def _format_opinion(cls, instance):
        """Turn one stored opinion into the text inserted into the prompt.

        ``instance`` must provide a ``'response'`` string and a ``'score'``;
        a score of 0 yields a "no anomaly" conclusion, anything else an
        "anomaly" conclusion. The text after ``_ANSWERS_MARKER`` is kept
        (skipping the single character that follows the marker). If the
        marker is absent, the whole response is used rather than failing.
        """
        text = instance['response']
        marker_pos = text.find(cls._ANSWERS_MARKER)
        if marker_pos != -1:
            text = text[marker_pos + len(cls._ANSWERS_MARKER) + 1:]
        analysis = text.lstrip('\n').rstrip('\n') + '\n'
        if instance['score'] == 0:
            analysis += 'Conclusion: No, there is no anomaly.\n'
        else:
            analysis += 'Conclusion: Yes, there is an anomaly.\n'
        return analysis

    def forward(self, pixel_values, num_patches_list, own_opinion, other_opinion, visual_features=None):
        """Predict 0 (normal) / 1 (abnormal) for a batch of video snippets.

        Args:
            pixel_values: 5-D tensor, flattened below from
                (batch, frames, C, H, W) to (batch * frames, C, H, W).
            num_patches_list: Not used; each sample is treated as having
                ``frames`` patches.
            own_opinion: Stored opinion shown as "Opinion 2" in the prompt.
            other_opinion: Stored opinion shown as "Opinion 1" in the prompt.
                Both are formatted by ``_format_opinion``; the same prompt is
                used for every batch element.
            visual_features: Optional precomputed visual features; when given,
                generation goes through ``batch_chat_with_cache``.

        Returns:
            ``(predict_labels, response)``: a tensor with one 0/1 label per
            batch element on ``pixel_values``' device, and the text before
            ``'Output'`` of the last generated response ('' for an empty batch).
        """
        self.model.eval()
        batch_size = pixel_values.size(0)
        frame_size = pixel_values.size(1)

        # Flatten the frame axis into the batch axis for the vision encoder.
        pixel_values = pixel_values.to(torch.bfloat16)
        pixel_values = torch.reshape(pixel_values, (batch_size * frame_size, pixel_values.size(2), pixel_values.size(3), pixel_values.size(4)))

        # One "FrameN: <image>" placeholder per frame.
        video_prefix = ''.join([f'Frame{i + 1}: <image>\n' for i in range(frame_size)])

        own_analysis = self._format_opinion(own_opinion)
        other_analysis = self._format_opinion(other_opinion)

        model_instruct = ("You are the model.\n"
        "** Model Descriptions: **\n"
        "You are designed to do binary classification. The input is a sequence of video frames for identifying whether there is an anomaly in the video. You need to output the class label, i.e., an integer in the set {0, 1}. 0 represents normal video, and 1 represents abnormal video. Please follow the instruction below to make a conclusion.\n"
        "** Instruction: **\n"
        "There are two analyses below, please conclude your answer to 'Is there any anomaly in the video?' in 'Yes, there is an anomaly' or 'No, there is no anomaly' by watching the video carefully and pondering whether these two opinions are consistent.\n"
        "** Opinion 1: **\n"
        f'{other_analysis}'
        "** Opinion 2: **\n"
        f'{own_analysis}'
        "** Input: **\n"
        "[$Data]\n"
        "Please give your output strictly in the following format:\n"
        "```\n"
        "New Analysis and Conclusion: [Please provide a new analysis with these two opinions and make an conclusion.]\n"
        "Output:\n"
        "[ONLY the integer class label]\n"
        "```\n"
        "Please ONLY reply according to this format, don't give me any other words.")

        question = model_instruct.replace('$Data', video_prefix)

        questions = [question] * batch_size
        num_patches_list_real = [frame_size] * batch_size

        # Pass a copy so the chat helpers cannot mutate the shared config
        # (they set 'eos_token_id' on the dict they receive).
        generation_config = dict(self.generation_config)
        if visual_features is not None:
            responses = self.batch_chat_with_cache(pixel_values, visual_features, num_patches_list=num_patches_list_real, questions=questions, generation_config=generation_config)
        else:
            responses = self.model.batch_chat(self.tokenizer, pixel_values, num_patches_list=num_patches_list_real, questions=questions, generation_config=generation_config)

        # Heuristic label parsing: any '0' in the text after the last
        # 'Output' means "normal"; otherwise the snippet is labelled abnormal.
        predict_labels = []
        response = ''
        for raw_response in responses:
            split_response = raw_response.split('Output')
            response = split_response[0]
            predict_labels.append(0 if '0' in split_response[-1] else 1)
        predict_labels = torch.tensor(predict_labels).to(pixel_values.device)

        return predict_labels, response

    def batch_chat_with_cache(self, pixel_values, visual_features, questions, generation_config, num_patches_list=None,
                   history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
                   IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
        """Batched single-turn chat that also forwards ``visual_features`` to
        ``self.model.generate``.

        Each question gets the model's conversation template and has its first
        ``<image>`` placeholder expanded to ``num_image_token * num_patches``
        context tokens. ``verbose`` and ``image_counts`` are accepted but unused.

        Raises:
            NotImplementedError: if ``history`` is given or ``return_history``
                is set (multi-turn chat is not supported).

        Returns:
            One decoded response string per question, truncated at the
            template separator.
        """
        if history is not None or return_history:
            raise NotImplementedError('Now multi-turn chat is not supported in batch_chat.')

        img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.model.img_context_token_id = img_context_token_id

        queries = []
        for idx, num_patches in enumerate(num_patches_list):
            question = questions[idx]
            if pixel_values is not None and '<image>' not in question:
                question = '<image>\n' + question
            template = get_conv_template(self.model.template)
            template.system_message = self.model.system_message
            template.append_message(template.roles[0], question)
            template.append_message(template.roles[1], None)
            query = template.get_prompt()

            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.model.num_image_token * num_patches + IMG_END_TOKEN
            query = query.replace('<image>', image_tokens, 1)
            queries.append(query)

        # Left padding keeps every prompt flush against the generated tokens.
        self.tokenizer.padding_side = 'left'
        model_inputs = self.tokenizer(queries, return_tensors='pt', padding=True)
        input_ids = model_inputs['input_ids'].cuda()
        attention_mask = model_inputs['attention_mask'].cuda()
        eos_token_id = self.tokenizer.convert_tokens_to_ids(template.sep)
        # Build a new dict instead of mutating the caller's config in place.
        generation_config = dict(generation_config, eos_token_id=eos_token_id)

        generation_output = self.model.generate(
            pixel_values=pixel_values,
            visual_features=visual_features,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_config
        )
        responses = self.tokenizer.batch_decode(generation_output, skip_special_tokens=True)
        responses = [response.split(template.sep)[0].strip() for response in responses]

        return responses


# Hugging Face checkpoint id; weights and tokenizer files are cached in ./cache.
path = 'OpenGVLab/InternVL2_5-8B'

# InternVL ships custom modelling code, hence trust_remote_code. The weights
# are loaded directly in bfloat16 with flash attention enabled, and the model
# is switched to inference mode.
model = AutoModel.from_pretrained(
    path,
    cache_dir='./cache',
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_flash_attn=True,
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(
    path,
    cache_dir='./cache',
    trust_remote_code=True,
)

# Deterministic greedy decoding: one beam, no sampling, at most 1024 new tokens.
generation_config = {
    'num_beams': 1,
    'max_new_tokens': 1024,
    'do_sample': False,
}


# Evaluation dataset; the inference loop unpacks each item as
# (video_path, video_name, n_frms).
test_dataset = Video_Instruct_Dataset_Inference(vis_root='data/ucf/frames/', ann_root='data/UCF_Eval.json')

# batch_size=1 and no shuffling: the loop relies on a stable item order,
# because items are selected by their position in the loader.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=16)

# Prompt template read from disk. NOTE(review): it is stored on the wrapper,
# but VideoAnomalyDetectionModelInference.forward builds its own prompt, so
# this text is not used for generation.
with open('model_instruct_corrected_reasonable.txt', 'r') as instruct_file:
    model_instruct = instruct_file.read()

iter_test_loader = iter(test_loader)

fps = 30            # presumably the frame rate of the extracted frames
t_window = 10       # snippet span is fps * t_window frames (presumably seconds)
sampling_rate = 16  # one snippet centred on every `sampling_rate`-th frame
snippet_len = 8     # frames sampled per snippet
image_transform = test_dataset.transform

dict_all = {}
count = -1          # position in test_loader; incremented before each use

# `new_qs` and `optimizer_instruct` were referenced here without ever being
# defined (NameError at startup). The wrapper only stores them, so they are
# passed explicitly as None.
new_qs = None
optimizer_instruct = None
all_qs = new_qs

# Initialize Lightning Model
lightning_model = VideoAnomalyDetectionModelInference(
    model=model, tokenizer=tokenizer, optimizer_instruct=optimizer_instruct,
    model_instruct=model_instruct, new_qs=all_qs, generation_config=generation_config
).eval().cuda()


def _load_json(json_path):
    """Return the parsed contents of the JSON file at ``json_path``."""
    with open(json_path, 'r') as fh:
        return json.load(fh)


def _load_snippet_pixels(video_dir, frame_indices):
    """Load and preprocess the frames of one snippet.

    Frames are read from ``<video_dir>/<index:06d>.jpg``. Each frame is passed
    through ``dynamic_preprocess(..., grid_size=1)`` and ``image_transform``
    and contributes exactly one patch.

    Returns:
        ``(pixel_values, num_patches_list)``: the concatenated frames with a
        leading batch axis of size 1, cast to bfloat16 on the GPU, and the
        per-frame patch counts.
    """
    pixel_values_list, num_patches_list = [], []
    for frame_idx in frame_indices:
        frame_path = video_dir + '/' + '{:06d}'.format(frame_idx) + '.jpg'
        # The context manager closes the file once the RGB copy is made.
        with Image.open(frame_path) as img:
            frame = img.convert('RGB')

        stitched_image = dynamic_preprocess([frame], grid_size=1)
        frame_pixels = torch.stack([image_transform(stitched_image)])

        pixel_values_list.append(frame_pixels)
        num_patches_list.append(frame_pixels.shape[0])

    pixel_values = torch.cat(pixel_values_list).unsqueeze(0)
    return pixel_values.to(torch.bfloat16).cuda(), num_patches_list


# Half of the snippet window, in frames.
half_window = 1/2*fps*t_window

# Iterating directly ends cleanly when the loader is exhausted. The old
# `while True` / `except StopIteration` also silently swallowed any
# StopIteration raised inside the body.
for video_path, video_name, n_frms in iter_test_loader:
    count += 1

    # Only process videos assigned to the selected split.
    if count not in folder_test_set:
        continue

    video_id = video_name[0]
    print("testing video " + video_id)

    # Skip videos whose results were already written (allows resuming).
    output_path = save_folder + "{}.json".format(video_id)
    if os.path.exists(output_path):
        continue

    start, end = 0, n_frms.item()

    opinion1 = _load_json(save_folder_own + "{}.json".format(video_id))
    opinion2 = _load_json(save_folder_others + "{}.json".format(video_id))

    dict_single = {}
    for ind in range(start, end, sampling_rate):
        # Window centred on `ind`, clipped to the video bounds.
        snippet_start = max(ind - half_window, start)
        snippet_end = min(ind + half_window, end)

        # Exactly `snippet_len` evenly spaced frames in [start, end). Unlike a
        # float-step np.arange, linspace cannot yield an extra element.
        snippet_indices = np.linspace(snippet_start, snippet_end, snippet_len, endpoint=False).astype(int).tolist()
        pixel_values, num_patches_list = _load_snippet_pixels(video_path[0], snippet_indices)

        # The stored opinions are keyed by the anchor frame index as a string.
        own_opinion = opinion1[str(ind)]
        other_opinion = opinion2[str(ind)]
        predict_labels, response = lightning_model.forward(
            pixel_values, num_patches_list, own_opinion, other_opinion, None)

        dict_single[ind] = {'start': snippet_start, 'end': snippet_end,
                            'score': predict_labels.item(), 'response': response}

    with open(output_path, "w") as outfile:
        json.dump(dict_single, outfile)
 